#include <stdio.h>

#include <cmath>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include <torch/torch.h>

using namespace torch;
using namespace std;
namespace nn = nn;

//
//// 3x3 Convolution with padding
//nn::Conv2d conv3x3(int64_t in_planes, int64_t out_planes, int64_t stride = 1, int64_t padding = 1, bool bias = false) {
//    return nn::Conv2d(
//            nn::Conv2dOptions(in_planes, out_planes, 3).stride(stride).padding(padding).bias(false));
//}
//
//struct ConvBlockImpl : nn::Module {
//    nn::BatchNorm2d bn1, bn2, bn3;
//    nn::Conv2d conv1, conv2, conv3;
//    std::optional<nn::Sequential> downsample;
//
//    ConvBlockImpl(int64_t in_planes, int64_t out_planes)
//            : bn1(in_planes),
//              conv1(conv3x3(in_planes, out_planes / 2)),
//              bn2(out_planes / 2),
//              conv2(conv3x3(out_planes / 2, out_planes / 4)),
//              bn3(out_planes / 4),
//              conv3(conv3x3(out_planes / 4, out_planes / 4)) {
//
//        if (in_planes != out_planes) {
//            downsample = nn::Sequential(
//                    nn::BatchNorm2d(in_planes),
//                    nn::ReLU(true),
//                    nn::Conv2d(nn::Conv2dOptions(in_planes, out_planes, 1).stride(1).bias(false))
//            );
//        } else {
//            downsample = nullptr;
//        }
//
//        register_module("bn1", bn1);
//        register_module("conv1", conv1);
//        register_module("bn2", bn2);
//        register_module("conv2", conv2);
//        register_module("bn3", bn3);
//        register_module("conv3", conv3);
//        if (downsample.has_value()) {
//            register_module("downsample", downsample.value());
//        }
//    }
//
//    torch::Tensor forward(torch::Tensor x) {
//        auto residual = x;
//
//        auto out1 = conv1->forward(relu(bn1->forward(x)));
//        auto out2 = conv2->forward(relu(bn2->forward(out1)));
//        auto out3 = conv3->forward(relu(bn3->forward(out2)));
//
//        out3 = torch::cat({out1, out2, out3}, 1);
//
//        if (downsample.has_value()) {
//            residual = downsample.value()->forward(residual);
//        }
//
//        out3 += residual;
//
//        return out3;
//    }
//};
//
//TORCH_MODULE(ConvBlock);
//
//struct HourGlassImpl : nn::Module {
//    int64_t num_modules;
//    int64_t depth;
//    int64_t features;
//
//    HourGlassImpl(int64_t num_modules, int64_t depth, int64_t num_features)
//            : num_modules(num_modules), depth(depth), features(num_features) {
//        _generate_network(depth);
//    }
//
//    void _generate_network(int64_t level) {
//        register_module("b1_" + std::to_string(level), ConvBlock(features, features));
//        register_module("b2_" + std::to_string(level), ConvBlock(features, features));
//
//        if (level > 1) {
//            _generate_network(level - 1);
//        } else {
//            register_module("b2_plus_" + std::to_string(level), ConvBlock(features, features));
//        }
//
//        register_module("b3_" + std::to_string(level), ConvBlock(features, features));
//    }
//
//    torch::Tensor _forward(int64_t level, torch::Tensor inp) {
//        auto up1 = _modules["b1_" + std::to_string(level)]->as<ConvBlock>()->forward(inp);
//        auto low1 = avg_pool2d(inp, 2, 2);
//        auto low2 = _modules["b2_" + std::to_string(level)]->as<ConvBlock>()->forward(low1);
//
//        if (level > 1) {
//            low2 = _forward(level - 1, low2);
//        } else {
//            low2 = _modules["b2_plus_" + std::to_string(level)]->as<ConvBlock>()->forward(low2);
//        }
//
//        auto low3 = _modules["b3_" + std::to_string(level)]->as<ConvBlock>()->forward(low2);
//        auto up2 = interpolate(low3, upsample_options().scale_factor({2.0, 2.0}).mode(torch::kNearest));
//
//        return up1 + up2;
//    }
//
//    torch::Tensor forward(torch::Tensor x) {
//        return _forward(depth, x);
//    }
//};
//
//TORCH_MODULE(HourGlass);
//
//struct FANImpl : nn::Module {
//    int64_t num_modules;
//
//    FANImpl(int64_t num_modules = 1)
//            : num_modules(num_modules) {
//        conv1 = register_module("conv1", Conv2d(Conv2dOptions(3, 64, 7).stride(2).padding(3)));
//        bn1 = register_module("bn1", nn::BatchNorm2d(64));
//        conv2 = register_module("conv2", ConvBlock(64, 128));
//        conv3 = register_module("conv3", ConvBlock(128, 128));
//        conv4 = register_module("conv4", ConvBlock(128, 256));
//
//        for (int64_t hg_module = 0; hg_module < num_modules; ++hg_module) {
//            add_module("m" + std::to_string(hg_module), HourGlass(1, 4, 256));
//            add_module("top_m_" + std::to_string(hg_module), ConvBlock(256, 256));
//            add_module("conv_last" + std::to_string(hg_module),
//                       Conv2d(Conv2dOptions(256, 256, 1).stride(1).padding(0)));
//            add_module("bn_end" + std::to_string(hg_module), nn::BatchNorm2d(256));
//            add_module("l" + std::to_string(hg_module), Conv2d(Conv2dOptions(256, 68, 1).stride(1).padding(0)));
//
//            if (hg_module < num_modules - 1) {
//                add_module("bl" + std::to_string(hg_module), Conv2d(Conv2dOptions(256, 256, 1).stride(1).padding(0)));
//                add_module("al" + std::to_string(hg_module), Conv2d(Conv2dOptions(68, 256, 1).stride(1).padding(0)));
//            }
//        }
//    }
//
//    torch::Tensor forward(torch::Tensor x) {
//        x = relu(bn1->forward(conv1->forward(x)));
//        x = avg_pool2d(conv2->forward(x), 2, 2);
//        x = conv3->forward(x);
//        x = conv4->forward(x);
//
//        auto previous = x;
//        std::vector<torch::Tensor> outputs;
//
//        for (int64_t i = 0; i < num_modules; ++i) {
//            auto hg = _modules["m" + std::to_string(i)]->as<HourGlass>()->forward(previous);
//
//            auto ll = _modules["top_m_" + std::to_string(i)]->as<ConvBlock>()->forward(hg);
//            ll = relu(_modules["bn_end" + std::to_string(i)]->as<nn::BatchNorm2d>()->forward(
//                    _modules["conv_last" + std::to_string(i)]->as<Conv2d>()->forward(ll)
//            ));
//
//            auto tmp_out = _modules["l" + std::to_string(i)]->as<Conv2d>()->forward(ll);
//            outputs.push_back(tmp_out);
//
//            if (i < num_modules - 1) {
//                ll = _modules["bl" + std::to_string(i)]->as<Conv2d>()->forward(ll);
//                auto tmp_out_ = _modules["al" + std::to_string(i)]->as<Conv2d>()->forward(tmp_out);
//                previous = previous + ll + tmp_out_;
//            }
//        }
//
//        return torch::stack(outputs);
//    }
//
//    Conv2d conv1;
//    nn::BatchNorm2d bn1;
//    ConvBlock conv2, conv3, conv4;
//};
//
//TORCH_MODULE(FAN);
//
//struct BottleneckImpl : nn::Module {
//    static const int expansion = 4;
//    Conv2d conv1, conv2, conv3;
//    nn::BatchNorm2d bn1, bn2, bn3;
//    nn::ReLU relu;
//    nn::Sequential downsample;
//    int64_t stride;
//
//    BottleneckImpl(int64_t inplanes, int64_t planes, int64_t stride = 1, nn::Sequential downsample = nullptr)
//            : conv1(Conv2d(Conv2dOptions(inplanes, planes, 1).bias(false))),
//              bn1(planes),
//              conv2(Conv2d(Conv2dOptions(planes, planes, 3).stride(stride).padding(1).bias(false))),
//              bn2(planes),
//              conv3(Conv2d(Conv2dOptions(planes, planes * 4, 1).bias(false))),
//              bn3(planes * 4),
//              relu(true),
//              downsample(downsample),
//              stride(stride) {
//
//        register_module("conv1", conv1);
//        register_module("bn1", bn1);
//        register_module("conv2", conv2);
//        register_module("bn2", bn2);
//        register_module("conv3", conv3);
//        register_module("bn3", bn3);
//        register_module("relu", relu);
//        if (downsample != nullptr) {
//            register_module("downsample", downsample);
//        }
//    }
//
//    torch::Tensor forward(torch::Tensor x) {
//        auto residual = x;
//
//        auto out = relu(bn1->forward(conv1->forward(x)));
//        out = relu(bn2->forward(conv2->forward(out)));
//        out = bn3->forward(conv3->forward(out));
//
//        if (downsample != nullptr) {
//            residual = downsample->forward(x);
//        }
//
//        out += residual;
//        out = relu(out);
//
//        return out;
//    }
//};
//
//TORCH_MODULE(Bottleneck);
//
//struct ResNetDepthImpl : nn::Module {
//    int64_t inplanes;
//    Conv2d conv1;
//    nn::BatchNorm2d bn1;
//    nn::ReLU relu;
//    nn::MaxPool2d maxpool;
//    nn::Sequential layer1, layer2, layer3, layer4;
//    nn::AvgPool2d avgpool;
//    nn::Linear fc;
//
//    ResNetDepthImpl(int64_t num_classes = 68)
//            : inplanes(64),
//              conv1(Conv2d(Conv2dOptions(3 + 68, 64, 7).stride(2).padding(3).bias(false))),
//              bn1(64),
//              relu(true),
//              maxpool(3, 2, 1),
//              avgpool(7),
//              fc(512 * BottleneckImpl::expansion, num_classes) {
//
//        layer1 = _make_layer(64, 3);
//        layer2 = _make_layer(128, 8, 2);
//        layer3 = _make_layer(256, 36, 2);
//        layer4 = _make_layer(512, 3, 2);
//
//        register_module("conv1", conv1);
//        register_module("bn1", bn1);
//        register_module("relu", relu);
//        register_module("maxpool", maxpool);
//        register_module("layer1", layer1);
//        register_module("layer2", layer2);
//        register_module("layer3", layer3);
//        register_module("layer4", layer4);
//        register_module("avgpool", avgpool);
//        register_module("fc", fc);
//    }
//
//    nn::Sequential _make_layer(int64_t planes, int64_t blocks, int64_t stride = 1) {
//        nn::Sequential downsample = nullptr;
//        if (stride != 1 || inplanes != planes * BottleneckImpl::expansion) {
//            downsample = nn::Sequential(
//                    Conv2d(Conv2dOptions(inplanes, planes * BottleneckImpl::expansion, 1).stride(stride).bias(false)),
//                    nn::BatchNorm2d(planes * BottleneckImpl::expansion)
//            );
//        }
//
//        nn::Sequential layers;
//        layers->push_back(Bottleneck(inplanes, planes, stride, downsample));
//        inplanes = planes * BottleneckImpl::expansion;
//        for (int64_t i = 1; i < blocks; ++i) {
//            layers->push_back(Bottleneck(inplanes, planes));
//        }
//
//        return layers;
//    }
//
//    torch::Tensor forward(torch::Tensor x) {
//        x = relu(bn1->forward(conv1->forward(x)));
//        x = maxpool->forward(x);
//
//        x = layer1->forward(x);
//        x = layer2->forward(x);
//        x = layer3->forward(x);
//        x = layer4->forward(x);
//
//        x = avgpool->forward(x);
//        x = x.view({x.size(0), -1});
//        x = fc->forward(x);
//
//        return x;
//    }
//};
//
//TORCH_MODULE(ResNetDepth);

namespace coastal {

    /// Base class for libtorch modules that adds typed lookup of registered
    /// sub-modules by name.
    class BaseModuleImpl : public torch::nn::Module {
    public:
        /// Looks up a registered sub-module by its qualified name and casts it
        /// to `ModuleType`.
        ///
        /// Names are the keys produced by `named_modules()`: a direct child is
        /// looked up by its registered name (e.g. "conv1"), a nested one by its
        /// dotted path (e.g. "block.conv1"). The module itself is not part of
        /// the search space.
        ///
        /// @tparam ModuleType  the module *implementation* type
        ///                     (e.g. `torch::nn::Conv2dImpl`, not `Conv2d`).
        /// @param  name        qualified sub-module name.
        /// @return shared pointer to the sub-module; never null.
        /// @throws std::runtime_error if no sub-module has that name, or if the
        ///         sub-module is not a `ModuleType`.
        template<typename ModuleType>
        std::shared_ptr<ModuleType> get_module(const std::string &name) {
            // include_self=false: only sub-modules are candidates, and the
            // lookup no longer requires `*this` to be owned by a shared_ptr
            // (including self goes through shared_from_this and throws for
            // modules created on the stack).
            // NOTE: named_modules() walks the whole sub-tree on every call;
            // keep the returned pointer if this is needed on a hot path.
            const auto modules = named_modules(/*name_prefix=*/std::string(), /*include_self=*/false);
            const std::shared_ptr<torch::nn::Module> *found = modules.find(name);
            if (found == nullptr) {
                throw std::runtime_error("Module not found: " + name);
            }

            // A failed downcast would otherwise hand callers a null pointer
            // that they immediately dereference.
            std::shared_ptr<ModuleType> typed = std::dynamic_pointer_cast<ModuleType>(*found);
            if (typed == nullptr) {
                throw std::runtime_error("Module '" + name + "' is not of the requested type");
            }
            return typed;
        }
    };

}

/// Minimal demo model: a single convolution that is registered under the
/// name "conv1" and fetched back by name in forward().
class MyModel : public coastal::BaseModuleImpl {
public:
    /// Registers a Conv2d with 1 input channel, 10 output channels and a
    /// kernel size of 5, then prints the registered module.
    MyModel() {
        // Register the sub-module so it can later be retrieved by name.
        const torch::nn::Conv2dOptions conv_options(1, 10, 5);
        auto registered_conv = register_module("conv1", torch::nn::Conv2d(conv_options));
        std::cout << "conv1: " << registered_conv << std::endl;
    }

    /// Applies the "conv1" sub-module (looked up via get_module) to `x`.
    torch::Tensor forward(torch::Tensor x) {
        const std::string conv_key{"conv1"};
        auto conv_impl = get_module<torch::nn::Conv2dImpl>(conv_key);
        return conv_impl->forward(x);
    }
};

/// Builds MyModel, runs one forward pass on random input and prints the
/// output shape. Returns EXIT_FAILURE if any std::exception escapes the run.
int main() {
    try {
        // Create the model instance.
        auto model = std::make_shared<MyModel>();

        // Random input tensor of shape (1, 1, 28, 28).
        auto input = torch::randn({1, 1, 28, 28});

        // Run the forward pass.
        auto output = model->forward(input);

        // Print the shape of the result tensor.
        std::cout << output.sizes() << std::endl;
    } catch (const std::exception &e) {
        std::cerr << "Exception: " << e.what() << std::endl;
        // Propagate the failure through the exit status instead of reporting
        // success after an error.
        return EXIT_FAILURE;
    }

    return EXIT_SUCCESS;
}